# This script runs Spark SQL queries over a URL-level representation of Twitter shares to
# generate the URL-, domain-, and user-level reference tables (plus the simulation input)
# used to calculate the Twitter-based URL and domain scores in the manuscript.

import os

os.environ["PYSPARK_PYTHON"] = "python3"
os.environ["PYSPARK_DRIVER_PYTHON"] = "python3"
import findspark
findspark.init()
findspark.os.environ["PYSPARK_PYTHON"] = "python3"
findspark.os.environ["PYSPARK_DRIVER_PYTHON"] = "python3"

import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

from pathlib import Path
import pandas as pd
import json


# Local project root; all reference tables are written beneath OUTPUT_FOLDER,
# which is created (with any missing parents) as soon as the script starts.
PROJECT_DIR = Path("/net/data/twitter-partisanship/")
OUTPUT_FOLDER = "reference_tables/"
(PROJECT_DIR / OUTPUT_FOLDER).mkdir(exist_ok=True, parents=True)
# HDFS project location (appears unused in this script).
HDFS_PROJECT_DIR = "hdfs://megatron.ccs.neu.edu/user/smccabe/curation_bubbles"

# Shares are kept only when created strictly after MIN_DATE and strictly before MAX_DATE.
MIN_DATE, MAX_DATE = "2016-12-31", "2019-01-01"

# A URL / domain is scored only when its share count is strictly greater than these.
URL_THRESHOLD, DOMAIN_THRESHOLD = 10, 50

# Domains excluded from the URL, domain, and user aggregations.
# The tuple's Python repr is interpolated directly into Spark SQL `NOT IN (...)`
# clauses, so keep it a tuple of plain strings with more than one element
# (a 1-tuple would render as `('x',)`, which is not valid SQL).
# Every entry needs its trailing comma: adjacent string literals without a comma
# are silently concatenated into a single (never-matching) domain.
EXCLUDED_DOMAINS = (
    "twitter.com",
    "wordpress.com",
    "blogspot.com",
    "instagram.com",
    "facebook.com",
    "fb.me",
    "youtube.com",
    "youtu.be",
    "linkedin.com",
    "reddit.com",
    "bit.ly",
    "amazon.com",
    "etsy.com",
)

# Two-letter state codes whose voters get NULL registration-based measures
# (party_reg_*). The tuple's repr is interpolated into SQL `IN (...)` clauses.
NO_PARTY_REG_STATES = (
    'MT', 'VT', 'MS', 'ND', 'AR', 'AL', 'WI', 'HI', 'MI', 'IN',
    'MO', 'MN', 'SC', 'VA', 'WA', 'OH', 'TN', 'GA', 'IL', 'TX',
)
# Two-letter state codes labelled "closed party" (presumably closed-primary
# states — TODO confirm); appears unused in this script.
CLOSED_PARTY_STATES = (
    'CT', 'DE', 'FL', 'KS', 'KY', 'ME', 'MD', 'DC', 'NE', 'NM', 'NY', 'PA', 'WY',
)

if __name__ == "__main__":
    # NOTE(review): only the session is created under this guard; every statement
    # after it runs at module level and uses `spark`, so importing this file
    # (instead of executing it) fails with a NameError.
    # NOTE(review): PYSPARK_PYTHON / PYSPARK_DRIVER_PYTHON are also set as
    # environment variables at the top of the file; as builder config keys they
    # are not spark.* properties and are presumably ignored by Spark — confirm.
    _builder = SparkSession.builder.appName("URL Partisanship")
    for _conf_key, _conf_value in (
        ("PYSPARK_PYTHON", "python3"),
        ("PYSPARK_DRIVER_PYTHON", "python3"),
        ("spark.excludeOnFailure.enabled", "true"),
        ("spark.excludeOnFailure.killExcludedExecutors", "true"),
        ("spark.excludeOnFailure.application.fetchFailure.enabled", "true"),
        # LEGACY parser so the Twitter `created_at` pattern below parses as before.
        ("spark.sql.legacy.timeParserPolicy", "LEGACY"),
    ):
        _builder = _builder.config(_conf_key, _conf_value)
    spark = _builder.getOrCreate()

print("loading data...")
# Tab-separated URL-level share records; the parsed URL/domain columns are
# renamed and Twitter's `created_at` string is parsed into a timestamp.
df = spark.read.csv(
    "hdfs://megatron.ccs.neu.edu/user/smccabe/fake_news/*/*",
    header='true',
    inferSchema='true',
    sep="\t",
)
df = df.withColumnRenamed("parsed_domain", "domain").withColumnRenamed("parsed_url", "url")
df = df.withColumn(
    "created_at",
    F.to_timestamp(F.col('created_at'), 'EEE MMM dd HH:mm:ss ZZZZZ yyyy'),
)
# Keep shares strictly inside (MIN_DATE, MAX_DATE); rows whose timestamp failed
# to parse (NULL) fail both comparisons and are dropped as well.
df = df.where(F.col("created_at") > F.lit(MIN_DATE)).where(F.col("created_at") < F.lit(MAX_DATE))

# Drop 'bare-domain' shares: keep only URLs with a non-empty path after
# "scheme://host/". URLs with no "/" after the host, or nothing after it,
# extract to "" and are removed (NULL urls are removed too); the helper
# column is discarded afterwards.
df = (
    df.withColumn('page', F.regexp_extract(F.col('url'), r"[\w]+:\/\/[^\/]*\/(.*)", 1))
    .where(F.length('page') > 0)
    .drop('page')
)

# Voter-file (TargetSmart) records reduced to the partisanship measures that the
# queries below average:
#   * party_score_* come from tsmart_partisan_score: < 35 -> 1, > 65 -> -1; the
#     trichotomized form uses 0 otherwise (including a NULL score), the
#     dichotomized form NULL; party_score_continuous = (50 - ROUND(score)) / 50.
#   * party_reg_* map vf_party "Republican" -> 1 and "Democrat" -> -1 and are NULL
#     when tsmart_state is in NO_PARTY_REG_STATES; the trichotomized form uses 0
#     for any other (or missing) party value.
vf = spark.read.csv(
    "/user/lab-lazer/TSmart-cleaner-Oct2017-rawFormat.csv",
    sep=",",
    header=True,
    inferSchema=True,
)
# Expose the Twitter profile ID as a string-typed `userid` join key.
vf = vf.withColumn("twProfileID", F.col("twProfileID").cast(T.StringType()))
vf = vf.selectExpr(
    """twProfileID AS userid""",
    """tsmart_partisan_score""",
    """ROUND(tsmart_partisan_score) AS party_score_rounded""",
    """CASE
        WHEN tsmart_partisan_score < 35 THEN 1 
        WHEN tsmart_partisan_score > 65 THEN -1 
        ELSE 0 
        END AS party_score_trichotomized""",
    """CASE
        WHEN tsmart_partisan_score < 35 THEN 1 
        WHEN tsmart_partisan_score > 65 THEN -1 
        ELSE NULL
        END AS party_score_dichotomized""",
    """(50 - ROUND(tsmart_partisan_score))/50 AS party_score_continuous""",
    f"""CASE
        WHEN tsmart_state IN {NO_PARTY_REG_STATES} THEN NULL
        WHEN vf_party = "Republican" THEN 1
        WHEN vf_party = "Democrat" THEN -1
        ELSE NULL
        END AS party_reg_dichotomized
    """,
    f"""CASE
        WHEN tsmart_state IN {NO_PARTY_REG_STATES} THEN NULL
        WHEN vf_party = "Republican" THEN 1
        WHEN vf_party = "Democrat" THEN -1
        ELSE 0
        END AS party_reg_trichotomized
    """,
)

# Keep only shares whose author matched a voter-file record.
df = df.join(vf, on="userid", how="inner")

# Load the output of the political classifier and mirror it to Spark.
# Each non-blank line of the JSONL file is one JSON record; this block reads its
# `url`, `headline`, `blurb`, and `politics_label` fields.
with open(PROJECT_DIR / "url_scores_blurbs_revised_politics.jsonl", "r", encoding="utf-8") as fin:
    # Skip blank lines (e.g. a stray empty line) instead of failing in json.loads.
    politics_labels_rows = [json.loads(line) for line in fin if line.strip()]

politics_labels = pd.DataFrame.from_records(politics_labels_rows)
# Drop records whose `url` contains " | " (presumably several URLs joined into
# one field — TODO confirm). `.copy()` makes the filtered frame independent so
# the column assignments below don't trigger SettingWithCopyWarning.
politics_labels = politics_labels[~politics_labels['url'].str.contains(" | ", regex=False)].copy()

# Flag non-empty blurbs before empty strings are converted to missing values.
politics_labels['has_blurb'] = ((pd.notnull(politics_labels['blurb'])) & (politics_labels['blurb'] != '')).astype(int)
politics_labels.loc[politics_labels['headline'] == '', 'headline'] = pd.NA
politics_labels.loc[politics_labels['blurb'] == '', 'blurb'] = pd.NA
# Assign the result back: `fillna(..., inplace=True)` on a column selection is
# chained assignment and does not update the frame under pandas Copy-on-Write.
politics_labels['politics_label'] = politics_labels['politics_label'].fillna(0)

politics = spark.createDataFrame(politics_labels)
# Left join: shares whose URL is absent from the classifier output keep NULL
# label columns (and are therefore ignored by AVG/CASE on politics_label).
df = df.join(politics, how="left", on='url')

# Indicator column: 0 when the share's `state` is in NO_PARTY_REG_STATES,
# otherwise 1 (including a NULL state). It is not referenced by the SQL
# queries in this script.
# NOTE(review): `state` is not among the voter-file columns selected above, so
# it presumably comes from the share-level data — confirm.
df = df.withColumn(
    "party_reg_state",
    F.when(df.state.isin([*NO_PARTY_REG_STATES]), 0).otherwise(1),
)

# Register the share-level data for the Spark SQL queries below.
df.createOrReplaceTempView('all_data')

# One row per URL with more than URL_THRESHOLD scored shares (non-NULL
# party_score_rounded) whose domain is not excluded: first-seen metadata, the
# earliest share time, sharer-average partisanship measures, and share counts
# by trichotomized partisan score (-1 = dem, 0 = ind, 1 = rep).
url_query = f"""
    SELECT
        url as url,
        FIRST(domain) as domain,
        FIRST(headline) AS headline,
        FIRST(blurb) AS blurb,
        FIRST(politics_score) AS politics_score,
        FIRST(politics_label) AS politics_label,
        MIN(created_at) as date,
        AVG(party_score_dichotomized) AS url_score_orig,
        AVG(party_score_continuous) AS url_score_continuous,
        AVG(party_reg_dichotomized) AS url_score_reg,
        AVG(party_reg_trichotomized) AS url_score_reg_ind,
        COUNT(party_score_rounded) AS num_shares,
        COUNT_IF(party_score_trichotomized = -1) AS num_dem_shares,
        COUNT_IF(party_score_trichotomized = 0) AS num_ind_shares,
        COUNT_IF(party_score_trichotomized = 1) AS num_rep_shares
    FROM
        all_data
    GROUP BY
        url
    HAVING
        num_shares > {URL_THRESHOLD} AND
        domain NOT IN {EXCLUDED_DOMAINS}
    """
url = spark.sql(url_query)

# Simulation input: for each URL kept in `url`, count its political shares
# (politics_label == 1) by the sharer's rounded partisan score. Each distinct
# rounded score becomes a column named with the 'n_%03.0f' pattern (e.g. n_042);
# URL/score combinations without shares are filled with 0.
sim_input = (
    df.where(F.col("politics_label")==F.lit(1))
    .join(url.select('url'), how="inner", on="url")
    .select("url", "domain", "party_score_rounded")
    .withColumn("value", F.format_string('n_%03.0f', 'party_score_rounded'))
    .groupBy("url", "domain", "value")
    .count()
    .groupBy("url", "domain")
    .pivot("value")
    .sum("count")
    .fillna(0)
)
# Write under OUTPUT_FOLDER (the directory created at start-up) instead of a
# hard-coded copy of its name, so the two cannot drift apart.
url.toPandas().to_csv(PROJECT_DIR / OUTPUT_FOLDER / "url_reference_table.tsv", sep="\t", index=False)
sim_input.toPandas().to_csv(PROJECT_DIR / OUTPUT_FOLDER / "simulation_input.tsv", sep="\t", index=False)

# One row per non-excluded domain with more than DOMAIN_THRESHOLD shares:
# fraction of labelled shares that are political (AVG ignores NULL labels from
# URLs missing in the classifier output), distinct users, share counts by
# trichotomized partisan score, and each partisanship measure averaged over all
# shares and over political shares only.
domain = spark.sql(
    f"""
    SELECT
        domain AS domain,
        AVG(politics_label) as pct_political,
        COUNT(DISTINCT userid) AS num_users,
        COUNT(*) AS num_shares,
        COUNT_IF(party_score_trichotomized = -1) AS num_dem_shares,
        COUNT_IF(party_score_trichotomized = 0) AS num_ind_shares,
        COUNT_IF(party_score_trichotomized = 1) AS num_rep_shares,
        AVG(party_score_dichotomized) AS domain_score_orig,
        AVG(party_score_continuous) AS domain_score_continuous,
        AVG(party_reg_dichotomized) AS domain_score_reg,
        AVG(party_reg_trichotomized) AS domain_score_reg_ind,
        AVG(CASE WHEN politics_label = 1 THEN party_score_dichotomized ELSE NULL END) AS domain_score_orig_political,
        AVG(CASE WHEN politics_label = 1 THEN party_score_continuous ELSE NULL END) AS domain_score_continuous_political,
        AVG(CASE WHEN politics_label = 1 THEN party_reg_dichotomized ELSE NULL END) AS domain_score_reg_political,
        AVG(CASE WHEN politics_label = 1 THEN party_reg_trichotomized ELSE NULL END) AS domain_score_reg_ind_political
    FROM
        all_data
    WHERE
        domain NOT IN {EXCLUDED_DOMAINS}
    GROUP BY
        domain
    HAVING
        num_shares > {DOMAIN_THRESHOLD}
    """
)
# Write under OUTPUT_FOLDER (the directory created at start-up) instead of a
# hard-coded copy of its name.
domain.toPandas().to_csv(PROJECT_DIR / OUTPUT_FOLDER / "domain_reference_table.tsv", sep="\t", index=False)

# Attach the URL- and domain-level scores to every share so they can be averaged
# per user below. Counts, text metadata, and the first-share date are dropped,
# and each table's num_shares is renamed so the two don't collide after the joins.
url_subset = url.drop(
    *(
        'num_dem_shares', 'num_ind_shares', 'num_rep_shares',
        'headline', 'politics_label', 'politics_score', 'blurb', 'date',
    )
).withColumnRenamed('num_shares', 'url_shares')
domain_subset = domain.drop(
    *('num_dem_shares', 'num_ind_shares', 'num_rep_shares', 'pct_political', 'num_users')
).withColumnRenamed('num_shares', 'domain_shares')

# Left joins: shares whose URL/domain did not pass the thresholds get NULL scores.
df = (
    df.join(url_subset, how='left', on=['url', 'domain'])
    .join(domain_subset, how='left', on='domain')
)

df.createOrReplaceTempView('all_data')

# One row per user: averages of the URL- and domain-level scores attached to the
# user's shares, over all shares and over political shares only. Only shares
# whose URL and domain passed their thresholds and whose domain is not excluded
# are counted. Every output column needs a unique alias; the *_reg_ind_political
# names mirror the domain table's domain_score_reg_ind_political.
user = spark.sql(
    f"""
    SELECT
        CAST(userid AS STRING) AS userid,
        AVG(url_score_orig) AS url_score_orig,
        AVG(url_score_continuous) AS url_score_continuous,
        AVG(url_score_reg) AS url_score_reg,
        AVG(url_score_reg_ind) AS url_score_reg_ind,
        AVG(CASE WHEN politics_label = 1 THEN url_score_orig ELSE NULL END) AS url_score_orig_political,
        AVG(CASE WHEN politics_label = 1 THEN url_score_continuous ELSE NULL END) AS url_score_continuous_political,
        AVG(CASE WHEN politics_label = 1 THEN url_score_reg ELSE NULL END) AS url_score_reg_political,
        AVG(CASE WHEN politics_label = 1 THEN url_score_reg_ind ELSE NULL END) AS url_score_reg_ind_political,
        AVG(domain_score_orig) AS domain_score_orig,
        AVG(domain_score_continuous) AS domain_score_continuous,
        AVG(domain_score_reg) AS domain_score_reg,
        AVG(domain_score_reg_ind) AS domain_score_reg_ind,
        AVG(CASE WHEN politics_label = 1 THEN domain_score_orig_political END) AS domain_score_orig_political,
        AVG(CASE WHEN politics_label = 1 THEN domain_score_continuous_political END) AS domain_score_continuous_political,
        AVG(CASE WHEN politics_label = 1 THEN domain_score_reg_political END) AS domain_score_reg_political,
        AVG(CASE WHEN politics_label = 1 THEN domain_score_reg_ind_political END) AS domain_score_reg_ind_political
    FROM
        all_data
    WHERE
        url_shares > {URL_THRESHOLD} AND
        domain_shares > {DOMAIN_THRESHOLD} AND
        domain NOT IN {EXCLUDED_DOMAINS}
    GROUP BY
        userid
        """
)

# Write under OUTPUT_FOLDER (the directory created at start-up) instead of a
# hard-coded copy of its name.
user.toPandas().to_csv(PROJECT_DIR / OUTPUT_FOLDER / "user_reference_table.tsv", sep="\t", index=False)
spark.stop()
